package com.jstarcraft.ai.jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.calibration.BinaryScoreClassifier;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.exceptions.UntrainedModelException;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.IndexTable;
import com.jstarcraft.ai.jsat.utils.ListUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * Implementation of the Stochastic Batch Perceptron (SBP) algorithm. Despite
 * its name, it solves the kernelized SVM problem. Because it is done
 * stochastically, it may not produce Support Vectors that the standard SVM
 * algorithm learns. It can learn at most one SV per iteration.
 * 
 * See:<br>
 * Cotter, A., Shalev-Shwartz, S., &amp; Srebro, N. (2012). <i>The Kernelized
 * Stochastic Batch Perceptron</i>. International Conference on Machine
 * Learning. Retrieved from <a href="http://arxiv.org/abs/1204.0566">
 * here</a>
 * 
 * @author Edward Raff
 */
public class SBP extends SupportVectorLearner implements BinaryScoreClassifier, Parameterized {

    private static final long serialVersionUID = 6112916782260792833L;
    private double nu = 0.1;
    private int iterations;
    private double burnIn = 1.0 / 5.0;

    /**
     * Index table over the C array, created lazily by
     * {@link #findGamma(double[], double)} and re-sorted on every later call.
     * Reset to {@code null} at the end of training.
     */
    private IndexTable it;

    /**
     * Creates a new SBP SVM learner
     * 
     * @param kernel     the kernel to use
     * @param cacheMode  the type of kernel cache to use
     * @param iterations the number of iterations of the algorithm to perform
     * @param v          the nu parameter, must be in the range (0, 1)
     * @throws IllegalArgumentException if {@code v} is NaN or outside (0, 1)
     */
    public SBP(KernelTrick kernel, CacheMode cacheMode, int iterations, double v) {
        super(kernel, cacheMode);
        setIterations(iterations);
        setNu(v);
    }

    /**
     * Copy constructor. Copies the hyper parameters (iterations, nu and burn in)
     * and, when the other object has them, the support vectors and their alphas.
     * 
     * @param other the object to copy
     */
    protected SBP(SBP other) {
        this(other.getKernel().clone(), other.getCacheMode(), other.iterations, other.nu);
        this.burnIn = other.burnIn;
        // Without the vectors the copy would report itself as untrained in classify
        if (other.vecs != null)
            this.vecs = new ArrayList<Vec>(other.vecs);
        if (other.alphas != null) {
            final double[] alphasCopy = Arrays.copyOf(other.alphas, other.alphas.length);
            if (this.vecs != null)
                setAlphas(alphasCopy);// same finalization call train() makes once vecs and alphas are final
            else
                this.alphas = alphasCopy;
        }
    }

    @Override
    public SBP clone() {
        return new SBP(this);
    }

    /**
     * Sets the number of iterations to go through. At most one SV can be learned
     * per iteration. If more iterations are done than there are SVs, it is highly
     * likely that O(n) SVs will be used, making the model very dense. It may take
     * far fewer iterations of the algorithm than there are data points to get good
     * accuracy.
     * 
     * @param iterations the number of iterations of the algorithm to perform
     */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * Returns the number of iterations the algorithm will perform
     * 
     * @return the number of iterations the algorithm will perform
     */
    public int getIterations() {
        return iterations;
    }

    /**
     * The nu parameter for this SVM is not the same as the standard nu-SVM
     * formulation, though it plays a similar role. It must be in the range (0, 1),
     * where small values indicate a linearly separable problem (in the kernel
     * space), and large values mean the problem is less separable. If the value is
     * too small for the problem, the SVM may fail to converge or produce good
     * results.
     * 
     * @param nu the value between (0, 1)
     * @throws IllegalArgumentException if {@code nu} is NaN or outside (0, 1)
     */
    public void setNu(double nu) {
        if (Double.isNaN(nu) || nu <= 0 || nu >= 1)
            throw new IllegalArgumentException("nu must be in the range (0, 1)");
        this.nu = nu;
    }

    /**
     * Returns the nu SVM parameter
     * 
     * @return the nu SVM parameter
     */
    public double getNu() {
        return nu;
    }

    /**
     * Sets the burn in fraction. SBP averages the intermediate solutions from each
     * step as the final solution. The intermediate steps of SBP are highly
     * correlated, and the beginning solutions are usually not as meaningful toward
     * the converged solution. To overcome this issue a certain fraction of the
     * iterations are not averaged into the final solution, making them the "burn
     * in" fraction. A value of 0.25 would then be ignoring the initial 25% of
     * solutions.
     * 
     * @param burnIn the ratio in [0, 1) of initial solutions to ignore
     * @throws IllegalArgumentException if {@code burnIn} is NaN or outside [0, 1)
     */
    public void setBurnIn(double burnIn) {
        if (Double.isNaN(burnIn) || burnIn < 0 || burnIn >= 1)
            throw new IllegalArgumentException("BurnInFraction must be in [0, 1), not " + burnIn);
        this.burnIn = burnIn;
    }

    /**
     * Returns the burn in fraction
     * 
     * @return the burn in fraction
     */
    public double getBurnIn() {
        return burnIn;
    }

    /**
     * Classifies the data point by the sign of its score: a negative score puts
     * all probability on class 0, any other score on class 1.
     * 
     * @param data the data point to classify
     * @return the hard (0 / 1 probability) classification result
     * @throws UntrainedModelException if no support vectors are available
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (vecs == null)
            throw new UntrainedModelException("Classifier has yet to be trained");

        CategoricalResults cr = new CategoricalResults(2);

        double sum = getScore(data);

        // SVM only says yes / no, can not give a percentage
        if (sum < 0)
            cr.setProb(0, 1.0);
        else
            cr.setProb(1, 1.0);

        return cr;
    }

    /**
     * Returns the raw score of the data point, computed as the kernel evaluation
     * sum of its numeric values against the stored vectors.
     * 
     * @param dp the data point to score
     * @return the score, where negative values map to class 0 in
     *         {@link #classify(DataPoint)}
     */
    @Override
    public double getScore(DataPoint dp) {
        return kEvalSum(dp.getNumericalValues());
    }

    /**
     * Trains the model. The {@code parallel} flag is ignored, training always
     * delegates to {@link #train(ClassificationDataSet)}.
     * 
     * @param dataSet  the binary classification data set to learn from
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the model on a binary classification problem. Each iteration picks
     * one data point, increases its alpha, updates the responses and rescales the
     * solution when its squared norm exceeds 1. The solutions after the burn in
     * period are averaged, scaled by the final gamma, and only the data points
     * with a non zero alpha are kept as support vectors.
     * 
     * @param dataSet the binary classification data set to learn from
     * @throws FailedToFitException if the data set does not have exactly 2 classes
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2)
            throw new FailedToFitException("SBP supports only binary classification");

        final int n = dataSet.size();
        // Number of initial iterations left out of the average. Iterations are
        // 1-based, so the ones with t > T_0 are averaged: exactly iterations - T_0
        final int T_0 = (int) Math.min((burnIn * iterations), iterations - 1);
        final int averaged = iterations - T_0;
        /*
         * Response values
         */
        double[] C = new double[n];
        double[] CSum = new double[n];
        alphas = new double[n];
        double[] alphasSum = new double[n];

        // Labels mapped from {0, 1} to {-1, +1}
        double[] y = new double[n];
        vecs = new ArrayList<Vec>(n);
        for (int i = 0; i < n; i++) {
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            vecs.add(dataSet.getDataPoint(i).getNumericalValues());
        }

        setCacheMode(getCacheMode());// Initiates the cache

        Random rand = RandomUtil.getRandom();
        double maxKii = 0;
        for (int i = 0; i < n; i++)
            maxKii = Math.max(maxKii, kEval(i, i));

        // Base step size, scaled by the largest self kernel value
        final double eta_0 = 1 / Math.sqrt(maxKii);

        double rSqrd = 0;

        for (int t = 1; t <= iterations; t++) {
            final double eta = eta_0 / Math.sqrt(t);
            final double gamma = findGamma(C, n * nu);

            final int i = sampleC(rand, n, C, gamma);

            alphas[i] += eta;
            rSqrd = updateLoop(rSqrd, eta, C, i, y, n);

            rSqrd = projectionStep(rSqrd, n, C);

            // Strictly greater, so the number of accumulated solutions matches the
            // divisor used for the average below
            if (t > T_0)
                for (int j = 0; j < n; j++) {
                    alphasSum[j] += alphas[j];
                    CSum[j] += C[j];
                }
        }

        // Take the averages
        for (int j = 0; j < n; j++) {
            alphas[j] = alphasSum[j] / averaged;
            C[j] = CSum[j] / averaged;
        }
        double gamma = findGamma(C, n * nu);
        for (int j = 0; j < n; j++)
            alphas[j] /= gamma;

        // Clean up to only the SVs
        int supportVectorCount = 0;
        for (int i = 0; i < vecs.size(); i++)
            if (alphas[i] != 0)// its a support vector
            {
                // Compact the SVs to the front, folding the label sign into alpha
                ListUtils.swap(vecs, supportVectorCount, i);
                alphas[supportVectorCount++] = alphas[i] * y[i];
            }

        vecs = new ArrayList<Vec>(vecs.subList(0, supportVectorCount));
        alphas = Arrays.copyOfRange(alphas, 0, supportVectorCount);

        it = null;
        setCacheMode(null);
        setAlphas(alphas);
    }

    /**
     * Rescales {@code C} and the alphas by {@code 1 / sqrt(rSqrd)} when
     * {@code rSqrd} is greater than 1.
     * 
     * @param rSqrd the current squared norm
     * @param n     the number of data points
     * @param C     the response values, modified in place
     * @return 1 if a rescaling was done, otherwise {@code rSqrd} unchanged
     */
    private double projectionStep(double rSqrd, final int n, double[] C) {
        if (rSqrd > 1)// 1^2 = 1, so just use sqrd version
        {
            final double rInv = 1 / Math.sqrt(rSqrd);

            for (int j = 0; j < n; j++) {
                C[j] *= rInv;
                alphas[j] *= rInv;
            }

            rSqrd = 1;
        }
        return rSqrd;
    }

    /**
     * Randomly selects the index of a data point whose response does not exceed
     * {@code gamma}. Up to 5 uniform draws are tried first. If the last draw is
     * still above {@code gamma}, one index is drawn uniformly from all indices
     * with {@code C[j] < gamma}.
     * <br>
     * NOTE(review): the fast path accepts {@code C[i] == gamma} while the slow
     * path requires a strict inequality - confirm which one is intended.
     * 
     * @param rand  the source of randomness
     * @param n     the number of data points
     * @param C     the response values
     * @param gamma the current threshold
     * @return the selected index
     * @throws FailedToFitException if the slow path finds no index with
     *                              {@code C[j] < gamma}
     */
    private int sampleC(Random rand, final int n, double[] C, final double gamma) throws FailedToFitException {
        int i;
        // Sample uniformly from C[i] <= gamma
        int attempts = 0;// you get 5 attempts to find one quickly
        do {
            i = rand.nextInt(n);
            attempts++;
        } while (C[i] > gamma && attempts < 5);

        if (C[i] > gamma)// find one the slow way
        {
            int candidates = 0;
            for (int j = 0; j < C.length; j++) {
                if (C[j] < gamma)
                    candidates++;
            }

            if (candidates == 0)
                throw new FailedToFitException("BUG: please report");

            // Walk the candidates and return the index of the randomly chosen one
            int toSkip = rand.nextInt(candidates);
            for (int j = 0; j < C.length; j++) {
                if (C[j] < gamma) {
                    if (toSkip == 0) {
                        i = j;
                        break;
                    }
                    toSkip--;
                }
            }
        }
        return i;
    }

    /**
     * Applies the effect of increasing the alpha of data point {@code i} by
     * {@code eta}: updates the squared norm and adds
     * {@code eta * y[i] * y[j] * k(i, j)} to every response {@code C[j]}.
     * 
     * @param rSqrd the squared norm before the update
     * @param eta   the step size
     * @param C     the response values, modified in place
     * @param i     the index of the selected data point
     * @param y     the labels as -1 / +1
     * @param n     the number of data points
     * @return the squared norm after the update
     */
    private double updateLoop(double rSqrd, final double eta, double[] C, int i, double[] y, final int n) {
        // Uses C[i] from before the loop below modifies it
        rSqrd += 2 * eta * C[i] + eta * eta * kEval(i, i);
        final double y_i = y[i];
        for (int j = 0; j < n; j++)
            C[j] += eta * y_i * y[j] * kEval(i, j);
        return rSqrd;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    // TODO add bias version of findGamma

    /**
     * Computes the gamma threshold for the given responses. The values of
     * {@code C} are visited in the order given by the index table, and the scan
     * stops at the first position whose score reaches {@code d}.
     * <br>
     * NOTE(review): at position 0 the division by {@code i} is a division by
     * zero, giving a non finite value. That value is returned if the scan stops
     * at position 1 - confirm this is intended.
     * 
     * @param C the response values
     * @param d the target value, {@code n * nu} for all callers in this class
     * @return the candidate gamma computed one position before the scan stopped,
     *         or 0 if the scan stopped at position 0 or {@code C} is empty
     */
    private double findGamma(double[] C, double d) {
        if (it == null)
            it = new IndexTable(C);
        else
            it.sort(C);// few will change from iteration to iteration, Java's TimSort should be able to
                       // exploit this

        double sum = 0;
        double max;
        double finalScore = 0, prevScore = 0;

        int i;
        for (i = 0; i < it.length(); i++) {
            max = C[it.index(i)];
            sum += max;

            double score = max * i - sum;
            prevScore = finalScore;
            finalScore = (d - max * i + sum) / i + max;

            if (score >= d)
                break;
        }

        return prevScore;
    }

}
